{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Customize Graph Convolution using Message Passing APIs\n",
    "\n",
    "In previous sessions, we have learned using the built-in [graph convolution modules](https://docs.dgl.ai/api/python/nn.pytorch.html#module-dgl.nn.pytorch.conv) to build a multi-layer graph neural network. However, sometimes one desires to invent a new way of aggregating neighbor information. DGL's message passing APIs are designed for this scenario.\n",
    "\n",
    "In this tutorial, you will learn:\n",
    "* What is under the hood of the `nn.SAGEConv` module in DGL?\n",
    "* DGL's message passing APIs.\n",
    "* Design a new graph convolution module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tutorial_utils import setup_tf\n",
    "setup_tf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A gentle explanation of the `SAGEConv` module\n",
    "\n",
    "Recall that a `SAGEConv` module aggregates neighbor information and generates new node representations as follows:\n",
    "\n",
    "\n",
    "$$\n",
    "h_{\\mathcal{N}(v)}^k\\leftarrow \\text{AGGREGATE}_k\\{h_u^{k-1},\\forall u\\in\\mathcal{N}(v)\\}\n",
    "$$\n",
    "\n",
    "$$\n",
    "h_v^k\\leftarrow \\text{ReLU}\\left(W^k\\cdot \\text{CONCAT}(h_v^{k-1}, h_{\\mathcal{N}(v)}^k) \\right)\n",
    "$$\n",
    "\n",
    "Here is its implementation in DGL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl.function as fn\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "class SAGEConv(layers.Layer):\n",
    "    \"\"\"Graph convolution module used by the GraphSAGE model.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    in_feat : int\n",
    "        Input feature size.\n",
    "    out_feat : int\n",
    "        Output feature size.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_feat, out_feat):\n",
    "        super(SAGEConv, self).__init__()\n",
    "        # A linear submodule for projecting the input and neighbor feature to the output.\n",
    "        self.linear = layers.Dense(out_feat)\n",
    "    \n",
    "    def call(self, g, h):\n",
    "        \"\"\"Forward computation\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "        g : Graph\n",
    "            The input graph.\n",
    "        h : Tensor\n",
    "            The input node feature.\n",
    "        \"\"\"\n",
    "        # All the `ndata` set within a local scope will be automatically popped out.\n",
    "        with g.local_scope():\n",
    "            g.ndata['h'] = h\n",
    "            # update_all is a message passing API.\n",
    "            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))\n",
    "            h_neigh = g.ndata['h_neigh']\n",
    "            h_total = tf.concat([h, h_neigh], dim=1)\n",
    "            return tf.nn.relu(self.linear(h_total))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The central piece in this code is the `g.update_all` function, which gathers and averages the neighbor features. There are three concepts here:\n",
    "* Message function `fn.copy_u('h', 'm')` that copies the node feature under name `'h'` as *messages* sent to neighbors.\n",
    "* Reduce function `fn.mean('m', 'h_neigh')` that averages all the received messages under name `'m'` and saves the result as a new node feature `'h_neigh'`.\n",
    "* `update_all` tells DGL to trigger the message and reduce functions for all the nodes and edges."
   ]
  },
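  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make these three concepts concrete, the following plain-Python sketch mimics what `g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))` computes on a toy graph. It is only an illustration, not how DGL implements message passing: the helper name `mean_aggregate` and the edge-list representation are made up for this example, and nodes without incoming edges are assumed to receive a zero vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_aggregate(num_nodes, edges, h):\n",
    "    # Hypothetical helper: average the features of each node's in-neighbors.\n",
    "    feat_size = len(h[0])\n",
    "    # Message phase (like fn.copy_u): every edge (u, v) carries a copy of h[u] to v.\n",
    "    mailbox = [[] for _ in range(num_nodes)]\n",
    "    for u, v in edges:\n",
    "        mailbox[v].append(h[u])\n",
    "    # Reduce phase (like fn.mean): each node averages the messages it received.\n",
    "    h_neigh = []\n",
    "    for msgs in mailbox:\n",
    "        if msgs:\n",
    "            h_neigh.append([sum(col) / len(msgs) for col in zip(*msgs)])\n",
    "        else:\n",
    "            # Assumption: nodes with no incoming messages get zeros.\n",
    "            h_neigh.append([0.0] * feat_size)\n",
    "    return h_neigh\n",
    "\n",
    "# Toy directed graph with edges 0->2, 1->2 and 2->0.\n",
    "edges = [(0, 2), (1, 2), (2, 0)]\n",
    "h = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]\n",
    "h_neigh = mean_aggregate(3, edges, h)\n",
    "print(h_neigh)  # prints [[5.0, 6.0], [0.0, 0.0], [2.0, 3.0]]"
   ]
  },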
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Message passing and GNNs\n",
    "\n",
    "The `update_all` is one of the **message passing APIs** in DGL, inspired by the Message Passing Neural Network proposed by [Gilmer et al.](https://arxiv.org/abs/1704.01212) Essentailly, they found many GNN models can fit into the following framework:\n",
    "\n",
    "$$\n",
    "m_{u\\sim v}^{(l)} = M^{(l)}\\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\\sim v}^{(l-1)}\\right)\n",
    "$$\n",
    "\n",
    "$$\n",
    "m_{v}^{(l)} = \\sum_{u\\in\\mathcal{N}(v)}m_{u\\sim v}^{(l)}\n",
    "$$\n",
    "\n",
    "$$\n",
    "h_v^{(l)} = U^{(l)}\\left(h_v^{(l-1)}, m_v^{(l)}\\right)\n",
    "$$\n",
    "\n",
    ", where the $M^{(l)}$ is called message function and the $\\sum$ is the reduce function. In DGL, we provide many built-in message and reduce functions under the `dgl.function` package.\n",
    "\n",
    "![api](../asset/dgl-mp.png)\n",
    "\n",
    "You can find more details in [the API doc](https://docs.dgl.ai/api/python/function.html)."
   ]
  },
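  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example, here is one way to read the `SAGEConv` equations from the beginning of this tutorial in terms of this framework. This mapping is an interpretation for illustration, with the layer index $l$ playing the role of $k$. The message function simply forwards the source node feature:\n",
    "\n",
    "$$\n",
    "M^{(l)}\\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\\sim v}^{(l-1)}\\right) = h_u^{(l-1)}\n",
    "$$\n",
    "\n",
    "The reduce function is a mean rather than a plain sum:\n",
    "\n",
    "$$\n",
    "m_{v}^{(l)} = \\frac{1}{|\\mathcal{N}(v)|}\\sum_{u\\in\\mathcal{N}(v)}m_{u\\sim v}^{(l)}\n",
    "$$\n",
    "\n",
    "The update function concatenates the node's own feature with the aggregated message, then applies the linear projection and ReLU:\n",
    "\n",
    "$$\n",
    "h_v^{(l)} = U^{(l)}\\left(h_v^{(l-1)}, m_v^{(l)}\\right) = \\text{ReLU}\\left(W^{(l)}\\cdot \\text{CONCAT}(h_v^{(l-1)}, m_v^{(l)})\\right)\n",
    "$$"
   ]
  },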
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DGL's message passing APIs allow one to quickly implement new graph convolution modules. For example, the following implements a new `SAGEConv` that aggregates neighbor representations using a weighted average."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SAGEConv(layers.Layer):\n",
    "    \"\"\"Graph convolution module used by the GraphSAGE model.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    in_feat : int\n",
    "        Input feature size.\n",
    "    out_feat : int\n",
    "        Output feature size.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_feat, out_feat):\n",
    "        super(SAGEConv, self).__init__()\n",
    "        # A linear submodule for projecting the input and neighbor feature to the output.\n",
    "        self.linear = layers.Dense(out_feat)\n",
    "    \n",
    "    def forward(self, g, h, w):\n",
    "        \"\"\"Forward computation\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "        g : Graph\n",
    "            The input graph.\n",
    "        h : Tensor\n",
    "            The input node feature.\n",
    "        w : Tensor\n",
    "            The edge weight.\n",
    "        \"\"\"\n",
    "        # All the `ndata` set within a local scope will be automatically popped out.\n",
    "        with g.local_scope():\n",
    "            g.ndata['h'] = h\n",
    "            g.edata['w'] = w\n",
    "            # update_all is a message passing API.\n",
    "            g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.mean('m', 'h_neigh'))\n",
    "            h_neigh = g.ndata['h_neigh']\n",
    "            h_total = tf.concat([h, h_neigh], axis=1)\n",
    "            return tf.nn.relu(self.linear(h_total))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Even more customization by user-defined function\n",
    "\n",
    "DGL allows user-defined message and reduce function for the maximal expressiveness. Here is a user-defined message function that is equivalent to `fn.u_mul_e('h', 'w', 'm')`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def u_mul_e_udf(edges):\n",
    "    return {'m' : edges.src['h'] * edges.data['w']}"
   ]
  },
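  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reduce functions can be user-defined in the same way. As a minimal sketch, the function below is intended to behave like `fn.mean('m', 'h_neigh')`; the name `mean_udf` is chosen here for illustration. It assumes that DGL groups the messages received by each node in `nodes.mailbox`, stacked along the second dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "def mean_udf(nodes):\n",
    "    # Assumed layout: nodes.mailbox['m'] has shape (num_nodes, num_messages, feat_size).\n",
    "    # Averaging over axis 1 reduces the messages received by each node.\n",
    "    return {'h_neigh': tf.reduce_mean(nodes.mailbox['m'], axis=1)}\n",
    "\n",
    "# Possible usage together with the message function above:\n",
    "# g.update_all(u_mul_e_udf, mean_udf)"
   ]
  },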
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recap\n",
    "\n",
    "* `dgl.nn` provides many popular modules for quick bootstrap.\n",
    "* Using the built-in message and reduce functions in `dgl.function` to customize a new NN module.\n",
    "* User-defined function provides even more flexibility."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "3.6.8",
   "language": "python",
   "name": "3.6.8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
